
import numpy as np
from channel import *
from eq import *
from simNeuralEQ import *
import torch

# Torch device string used by this module: "cuda" when torch can see a CUDA device, else "cpu".
if torch.cuda.is_available():
	device = "cuda"
else:
	device = "cpu"


class simSweep:
	'''
	Description:
		Simulate the selected equalizer over a sweep of SNR values.
		snrList may hold several SNRs; every equalizer method repeats its simulation once per
		channel output (one channel output per SNR when outputs are generated internally).
	'''

	def __init__(self, chSbr, eqSbr, snrList, originData, mod, chOutList=None, flagN=1, stateGen=False):
		'''
		Description:
			Normalize chSbr to unit L2 norm and use it to generate channel outputs.
			Unless chOutList is given, one Channel is built per SNR in snrList and originData is run through it.
			An eq instance is created and shared by all equalizer methods.
		Params:
			chSbr(float list)			: channel SBR for internal generation of channel output (integer lists are accepted).
			eqSbr(float list)			: SBR assumed by the equalizer.
			i.e) If chSbr == eqSbr, the equalizer sees the exact ISI of the channel; otherwise its accuracy degrades.
			snrList(float list)			: SNRs used to generate channel outputs; also used to label the printed results.
			originData(float list)		: Channel input data; also the reference sequence for BER checks.
			mod(str)					: Modulation.
			chOutList(float list, Optional)	: If specified, used instead of internally generated channel outputs.
			flagN(int, Optional)		: If 0, the internal channel adds no noise at all. If 1, noise is added.
			stateGen(bool, Optional)	: If True, eq generates states for viterbi and fwdBwd. False reduces __init__ time.
		Raises:
			ValueError: if chSbr has zero norm (all-zero or empty), since it cannot be normalized.
		'''
		sbr = np.array(chSbr)
		sbrNorm = np.linalg.norm(sbr)
		if sbrNorm == 0:
			raise ValueError("chSbr must contain at least one non-zero tap")
		# Out-of-place division: an in-place `/=` fails on an integer array (e.g. chSbr=[1, 0]).
		self.chSbr_ = sbr / sbrNorm
		self.snrList_ = snrList
		self.chList_ = []
		self.originData = originData
		self.mod_ = mod
		if chOutList is not None:
			self.chOutList_ = chOutList
		else:
			self.chOutList_ = []
			for snr in self.snrList_:
				ch = Channel(sbr=self.chSbr_, snr=snr)
				self.chList_.append(ch)
				self.chOutList_.append(ch.run(chIn=self.originData, flagN=flagN))
		self.eq = eq(estTap=eqSbr, mod=self.mod_, stateGen=stateGen)

	def fir(self, ffeTapNum=3):
		'''
		Description:
			Run eq.fir on every channel output with ffeTapNum taps, then check BER with eq.berChecker.
			Print the BER for each SNR.
		Params:
			ffeTapNum(int)	: number of ffe tap
		Returns:
			berList(list)	: BER per channel output, in chOutList_ order.
		'''
		print ("---------")
		print ("FIR start")
		print ("---------")
		# berChecker skips len(chSbr_) symbols at both ends of the sequence.
		offset = len(self.chSbr_)
		berList = []
		for k, chOut in enumerate(self.chOutList_):
			(seq, delay) = self.eq.fir(chOut, ffeTapNum)
			(ber, be) = self.eq.berChecker(self.originData, seq, delay=delay, offsetStart=offset, offsetEnd=offset)
			print ("SNR : ", self.snrList_[k], "ber: ", ber, "be: ", be, flush=True)
			berList.append(ber)
		print ("FIR BER list")
		print (berList)
		return berList

	def dfe(self, dfeTapNum=5):
		'''
		Description:
			Run eq.dfe on every channel output with dfeTapNum taps, then check BER with eq.berChecker.
			Print the BER for each SNR.
		Params:
			dfeTapNum(int)	: number of dfe tap
		Returns:
			berList(list)	: BER per channel output, in chOutList_ order.
		'''
		print ("---------")
		print ("DFE start")
		print ("---------")
		offset = len(self.chSbr_)
		berList = []
		for k, chOut in enumerate(self.chOutList_):
			(dfeOutReal, dfeOut, delay) = self.eq.dfe(chOut, dfeTapNum)
			(ber, be) = self.eq.berChecker(self.originData, dfeOut, delay=delay, offsetStart=offset, offsetEnd=offset)
			print ("SNR : ", self.snrList_[k], "ber: ", ber, "be: ", be, flush=True)
			berList.append(ber)
		print ("DFE BER LIST")
		print (berList)
		return berList

	# ---- Experimental: combined FFE + DFE ----
	def firDfe(self, ffeTapNum=None, ffeMaxTapNum=None, dfeTapNum=None):
		'''
		Description:
			Run eq.firDfe on every channel output, then check BER with eq.berChecker (log disabled).
			Print the BER for each SNR.
		Params:
			ffeTapNum(int)	: number of ffe tap
			ffeMaxTapNum(int) : index of the maximum ffe tap (passed to eq.firDfe as maxTapNum)
			dfeTapNum(int)	: number of dfe tap
		Returns:
			berList(list)	: BER per channel output, in chOutList_ order.
		'''
		print ("---------")
		print ("FIRDFE start")
		print ("---------")
		offset = len(self.chSbr_)
		berList = []
		for k, chOut in enumerate(self.chOutList_):
			(dfeOutReal, dfeOut, delay) = self.eq.firDfe(chOut, ffeTapNum=ffeTapNum, maxTapNum=ffeMaxTapNum, dfeTapNum=dfeTapNum)
			(ber, be) = self.eq.berChecker(self.originData, dfeOut, delay=delay, offsetStart=offset, offsetEnd=offset, log=False)
			print ("SNR : ", self.snrList_[k], "ber: ", ber, "be: ", be, flush=True)
			berList.append(ber)
		print ("FIRDFE BER LIST")
		print (berList)
		return berList

	def firDfeSearchMaxTap(self, ffeTapNum, dfeTapNum=None):
		'''
		Description:
			Exhaustively search the ffeMaxTapNum for firDfe that gives the lowest BER.
			Each candidate 0..ffeTapNum-1 is evaluated on one reference channel output,
			the one at index len(chOutList_)*3//4.
			NOTE(review): presumably chosen as a high-SNR point assuming snrList is ascending - confirm.
		Params:
			ffeTapNum(int)	: number of ffe tap (also the number of candidates searched)
			dfeTapNum(int, Optional)	: number of dfe tap
		Returns:
			int	: the best ffeMaxTapNum (the first one on ties).
		Raises:
			ValueError: if ffeTapNum < 1.
			IndexError: if there is no channel output to evaluate.
		'''
		print("----------")
		print("Search MaxTap for FIRDFE...")
		print("----------", flush=True)
		if ffeTapNum < 1:
			raise ValueError("ffeTapNum must be at least 1")
		if len(self.chOutList_) == 0:
			raise IndexError("chOutList_ is empty; nothing to search on")

		chOutRef = self.chOutList_[(len(self.chOutList_) * 3) // 4]
		offset = len(self.chSbr_)
		berList = np.ones(ffeTapNum)
		for ffeMaxTapNum in range(ffeTapNum):
			print(f"{ffeMaxTapNum}..", end=" ", flush=True)
			(dfeOutReal, dfeOut, delay) = self.eq.firDfe(chOutRef, ffeTapNum=ffeTapNum, maxTapNum=ffeMaxTapNum, dfeTapNum=dfeTapNum)
			(ber, be) = self.eq.berChecker(self.originData, dfeOut, delay=delay, offsetStart=offset, offsetEnd=offset, log=False)
			berList[ffeMaxTapNum] = ber
		minIdx = int(np.argmin(berList))
		print(f"The best ffeMaxTapNum: {minIdx}")
		return minIdx	# which is the best ffeMaxTapNum

	def viterbiOverlap(self, blockSizeList=None):
		'''
		Description:
			Run eq.viterbiOverlapPack on every channel output for each block size in blockSizeList,
			then check BER with eq.berChecker, so the performance of each block size can be compared.
			Print the BER for each SNR and each block size.
		Params:
			blockSizeList(int list, Optional)	: Size of a unit block for viterbi calculation. Defaults to [100].
		Returns:
			berList(list)	: flat list of BERs, SNR-major then block size.
		'''
		blockSizeList = blockSizeList if blockSizeList is not None else [100]
		print ("---------")
		print ("VITERBI start")
		print ("---------")
		offset = len(self.chSbr_)
		berList = []
		for k, chOut in enumerate(self.chOutList_):
			berBlockList = []
			for blockSize in blockSizeList:
				(score, mlSeq, V) = self.eq.viterbiOverlapPack(chOut, blockSize)
				(ber, be) = self.eq.berChecker(self.originData, mlSeq, offsetStart=offset, offsetEnd=offset)
				print ("SNR : ", self.snrList_[k], "ber: ", ber, "be: ", be, "blockSize: ", blockSize, flush=True)
				berBlockList.append(ber)
				berList.append(ber)
			print (berBlockList)
		print ("VITERBI BER LIST")
		print (berList)
		return berList

	def fwdBwd(self, fwdBwdLen, snrOvrd=None):
		'''
		Description:
			Run eq.fwdBwd on every channel output with fwdBwdLen, then check BER with eq.berChecker.
			Print the BER for each SNR.
		Params:
			fwdBwdLen(int)	: Size of a unit block for fwdBwd calculation.
			snrOvrd(float, Optional)	: If given, passed to eq.fwdBwd as snr instead of the per-output SNR from snrList.
		Returns:
			(berList, prob)	: BER per channel output, and the prob returned by eq.fwdBwd for the LAST
							  channel output (None if there are no channel outputs).
		'''
		print ("---------")
		print ("Forward-Backward start")
		print ("---------")
		berList = []
		prob = None
		for k, chOut in enumerate(self.chOutList_):
			snr = snrOvrd if snrOvrd is not None else self.snrList_[k]
			(seq, prob) = self.eq.fwdBwd(chOut, fwdBwdLen, snr=snr)
			# NOTE(review): fixed offsets of 5 here, whereas the other methods use len(self.chSbr_) - confirm intent.
			(ber, be) = self.eq.berChecker(self.originData, seq, delay=0, offsetStart=5, offsetEnd=5)
			print ("SNR : ", self.snrList_[k], "ber: ", ber, "be: ", be, flush=True)
			berList.append(ber)
		print ("FwdBwd BER LIST")
		print (berList)
		return berList, prob

	def fwd(self, fwdLen):
		'''
		Description:
			Run eq.fwd on every channel output with fwdLen, then check BER with eq.berChecker.
			Print the BER for each SNR.
		Params:
			fwdLen(int)	: Size of a unit block for fwd calculation.
		Returns:
			(berList, prob)	: BER per channel output, and the prob returned by eq.fwd for the LAST
							  channel output (None if there are no channel outputs).
		'''
		print ("---------")
		print ("Forward start")
		print ("---------")
		offset = len(self.chSbr_)
		berList = []
		prob = None
		for k, chOut in enumerate(self.chOutList_):
			(seq, prob) = self.eq.fwd(chOut, fwdLen, snr=self.snrList_[k])
			(ber, be) = self.eq.berChecker(self.originData, seq, delay=0, offsetStart=offset, offsetEnd=offset)
			print ("SNR : ", self.snrList_[k], "ber: ", ber, "be: ", be, flush=True)
			berList.append(ber)
		print ("FwdBER LIST")
		print (berList)
		return berList, prob

	def nnFwdBwd(self, neuralEQ, lossFn, batchSize, inSize, outSize, delay):
		'''
		Description:
			Evaluate a neural equalizer on every channel output via simNeuralEQ.evalNeuralEQ,
			which performs its own BER check (internally implemented in simNeuralEQ).
			Print the BER for each SNR.
		Params:
			neuralEQ(class)	: Pre-defined neural network class for nEQ.
			lossFn(str)	: loss function (mse, crossEntropy, manualCrossEntropy)
			batchSize(int) : Mini-batch size
			inSize(int)	: Size of neural network input
			outSize(int) : Size of neural network output
			delay(int) : Index of target data for inference.
		Returns:
			berList(list)	: test BER per channel output, in chOutList_ order.
		'''
		print ("---------")
		print ("nnFwdBwd start")
		print ("---------")
		berList = []
		# Empty train/test sets: the data is supplied per call through rxDataTestNew/txDataTestNew.
		simNEQ = simNeuralEQ(txDataTrain=[], rxDataTrain=[], txDataTest=[], rxDataTest=[], neuralEQ=neuralEQ, mod=self.mod_)
		for k, chOut in enumerate(self.chOutList_):
			testLoss, berTest = simNEQ.evalNeuralEQ(lossFn, batchSize=batchSize, inSize=inSize, outSize=outSize, delay=delay, rxDataTestNew=chOut, txDataTestNew=self.originData)
			print ("SNR : ", self.snrList_[k], "ber: ", berTest, flush=True)
			berList.append(berTest)
		print ("nnFwdBwd BER LIST")
		print (berList)
		return berList

		
